{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from bigbird import modeling\n",
    "from bigbird import utils\n",
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_config = {\n",
    "  # transformer basic configs\n",
    "  \"attention_probs_dropout_prob\": 0.1,\n",
    "  \"hidden_act\": 'relu',\n",
    "  \"hidden_dropout_prob\": 0.1,\n",
    "  \"hidden_size\": 768,\n",
    "  \"initializer_range\": 0.02,\n",
    "  \"intermediate_size\": 3072,\n",
    "  \"max_position_embeddings\": 4096,\n",
    "  \"max_encoder_length\": 2048,\n",
    "  \"max_decoder_length\": 512,\n",
    "  \"num_attention_heads\": 12,\n",
    "  \"num_hidden_layers\": 12,\n",
    "  \"type_vocab_size\": 2,\n",
    "  \"scope\": 'pegasus',\n",
    "  \"use_bias\": False,\n",
    "  \"rescale_embedding\": True,\n",
    "  \"vocab_model_file\": None,\n",
    "  # sparse mask configs\n",
    "  \"attention_type\": \"block_sparse\",\n",
    "  \"norm_type\": 'prenorm',\n",
    "  \"block_size\": 64,\n",
    "  \"num_rand_blocks\": 3,\n",
    "  \"vocab_size\": 32000,\n",
    "  \"beam_size\": 1,\n",
    "  \"alpha\": 0.0,\n",
    "  \"couple_encoder_decoder\": False\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = modeling.TransformerModel(bert_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "X = tf.placeholder(tf.int32, [None, None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Users/huseinzolkepli/Documents/Malaya/pretrained-model/bigbird/bigbird/modeling.py:226: where (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use tf.where in 2.0, which has the same broadcast rule as np.where\n",
      "Tensor(\"pegasus/strided_slice_1:0\", shape=(), dtype=int32)\n",
      "WARNING:tensorflow:From /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/tensorflow_core/python/util/deprecation.py:507: calling count_nonzero (from tensorflow.python.ops.math_ops) with axis is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "reduction_indices is deprecated, use axis instead\n"
     ]
    }
   ],
   "source": [
    "r = model(X, training = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_ids = r[0][2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "o = sess.run(pred_ids, feed_dict = {X: [[1] * 2048]})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
